import matplotlib
import matplotlib.pyplot as plt
import os
import sys
import random as rnd
import sigpy.mri as mr
import torch
import numpy as np
from graphviz import Digraph
from torch.autograd import Variable
from torch.autograd import Variable
import hdf5storage

from random import randint
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

## Path to the folder containing the demo dataset and the trained models
# change this path to match your setup
file_Initial = "/backup/HassanData/DL_dealiasing/"

## Define the network for real-time complex-difference reconstruction.
# Net.__init__ reads these module-level names when an instance is built; the
# layer parameter shapes depend on them, so they must match the saved checkpoints.
Hidden_layer = 64        # feature channels at the first U-Net level (x2, x4 deeper)
Conv_kernel = 3          # kernel size along the first two image axes
Conv_kerenl_time = 3     # kernel size along the third (time) axis
Padd_space = 1           # zero padding along the first two image axes
Padd_time = 1            # zero padding along the time axis
drop_out_level = 0.15    # Dropout3d probability
Bias = True              # whether the Conv3d layers have a bias term
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """3D U-Net (two max-pooling levels) for the complex-difference reconstruction.

    Input is a single-channel 5-D tensor ``(N, 1, D1, D2, D3)``. The output has the
    same shape: with the current kernel/padding settings (3 / 1) every convolution
    is size-preserving, and the transposed convolutions are given the size of the
    matching skip connection.

    Layer sizes come from the module-level hyper-parameters (``Hidden_layer``,
    ``Conv_kernel``, ``Conv_kerenl_time``, ``Padd_space``, ``Padd_time``,
    ``drop_out_level``, ``Bias``), read when an instance is constructed.

    Attribute names and sub-module names form the ``state_dict`` keys, so they
    must stay exactly as they are for saved checkpoints to load.
    """

    @staticmethod
    def _conv_bn(conv_name, bn_name, in_channels, out_channels):
        """Return a ``Sequential`` holding one Conv3d followed by BatchNorm3d.

        The child names are passed explicitly because they appear in the
        state_dict keys (e.g. ``conv_DL1.Conv_DL1.weight``).
        """
        block = torch.nn.Sequential()
        block.add_module(
            conv_name,
            nn.Conv3d(
                in_channels,
                out_channels,
                (Conv_kernel, Conv_kernel, Conv_kerenl_time),
                padding=(Padd_space, Padd_space, Padd_time),
                stride=1,
                bias=Bias,
            ),
        )
        block.add_module(bn_name, nn.BatchNorm3d(out_channels))
        return block

    @staticmethod
    def _dropout():
        """Return the Dropout3d applied after every conv/BN pair."""
        return nn.Dropout3d(p=drop_out_level, inplace=True)

    @staticmethod
    def _max_pool(name):
        """Return a ``Sequential`` holding one 2x2x2, stride-2 max-pool named ``name``."""
        pool = torch.nn.Sequential()
        pool.add_module(name, nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)))
        return pool

    def __init__(self):
        super().__init__()
        h = Hidden_layer

        # Encoder, level 1 (full resolution)
        self.conv_DL1 = self._conv_bn("Conv_DL1", "BN1_DL1", 1, h)
        self.DropOut1 = self._dropout()
        self.conv_DL1_v2 = self._conv_bn("Conv_DL1_v2", "BN1_DL1_v2", h, h)
        self.DropOut2 = self._dropout()
        self.conv_MP1 = self._max_pool("Max Pool 1")

        # Encoder, level 2 (after one pooling)
        self.conv_DL2 = self._conv_bn("Conv_DL2", "BN1_DL2", h, 2 * h)
        self.DropOut3 = self._dropout()
        self.conv_DL2_v2 = self._conv_bn("Conv_DL2_v2", "BN1_DL2_v2", 2 * h, 2 * h)
        self.DropOut4 = self._dropout()
        self.conv_MP2 = self._max_pool("Max Pool 2")

        # Bottleneck (after two poolings)
        self.conv_DL3 = self._conv_bn("Conv_DL3", "BN1_DL3", 2 * h, 4 * h)
        self.DropOut5 = self._dropout()
        self.conv_DL3_v2 = self._conv_bn("Conv_DL3_v2", "BN1_DL3_v2", 4 * h, 4 * h)
        self.DropOut6 = self._dropout()

        # Decoder, level 2: upsample, then concat with the level-2 skip (2h + 2h channels)
        self.convT1 = nn.ConvTranspose3d(4 * h, 2 * h, (2, 2, 2), stride=(2, 2, 2))
        self.conv_UP1 = self._conv_bn("Conv_UP1", "BN1_UP1", 4 * h, 2 * h)
        self.DropOut7 = self._dropout()
        self.conv_UP1_v2 = self._conv_bn("Conv_UP1_v2", "BN1_UP1_v2", 2 * h, 2 * h)
        self.DropOut8 = self._dropout()

        # Decoder, level 1: upsample, then concat with the level-1 skip (h + h channels)
        self.convT2 = nn.ConvTranspose3d(2 * h, h, (2, 2, 2), stride=(2, 2, 2))
        self.conv_UP2 = self._conv_bn("Conv_UP2", "BN1_UP2", 2 * h, h)
        self.DropOut9 = self._dropout()
        self.conv_UP2_v2 = self._conv_bn("Conv_UP2_v2", "BN1_UP2_v2", h, h)
        self.DropOut10 = self._dropout()

        # 1x1x1 projection back to a single output channel
        self.conv_final = torch.nn.Sequential()
        self.conv_final.add_module(
            "Conv Final",
            nn.Conv3d(h, 1, (1, 1, 1), padding=(0, 0, 0), stride=1, bias=Bias),
        )

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (N, 1, D1, D2, D3); return a same-shaped tensor."""
        # Sub-modules are invoked through __call__ (not .forward) so registered hooks run.
        x_down1 = F.relu(self.DropOut1(self.conv_DL1(x)))
        x_down1_v2 = F.relu(self.DropOut2(self.conv_DL1_v2(x_down1)))
        x_MaxPool = self.conv_MP1(x_down1_v2)

        x_down2 = F.relu(self.DropOut3(self.conv_DL2(x_MaxPool)))
        x_down2_v2 = F.relu(self.DropOut4(self.conv_DL2_v2(x_down2)))
        x_MaxPool_v2 = self.conv_MP2(x_down2_v2)

        x_down3 = F.relu(self.DropOut5(self.conv_DL3(x_MaxPool_v2)))
        x_down3_v2 = F.relu(self.DropOut6(self.conv_DL3_v2(x_down3)))

        # output_size restores odd lengths lost to floor-mode pooling, so the
        # upsampled tensor matches the skip connection for torch.cat.
        x_up1_ConvT = self.convT1(x_down3_v2, output_size=x_down2_v2.size())
        x_down2_up1_stack = torch.cat((x_down2_v2, x_up1_ConvT), 1)
        x_up1 = F.relu(self.DropOut7(self.conv_UP1(x_down2_up1_stack)))
        x_up1_v2 = F.relu(self.DropOut8(self.conv_UP1_v2(x_up1)))

        x_up2_ConvT = self.convT2(x_up1_v2, output_size=x_down1_v2.size())
        x_down1_up2_stack = torch.cat((x_down1_v2, x_up2_ConvT), 1)
        x_up2 = F.relu(self.DropOut9(self.conv_UP2(x_down1_up2_stack)))
        x_up2_v2 = F.relu(self.DropOut10(self.conv_UP2_v2(x_up2)))

        return self.conv_final(x_up2_v2)
net = Net()  # NOTE(review): `net` appears unused in this script; kept so the name still exists

## Load the trained 3D U-Net for real-time complex-difference reconstruction
PATH = os.path.join(file_Initial, "Model", "Unet3D", "")
PATH2 = os.path.join(PATH, "model.py")  # file holding the saved state_dict
netCD = Net()
deviceCD = torch.device("cuda:5")
# Wrap before loading: the saved keys presumably carry DataParallel's "module."
# prefix — TODO confirm against how the checkpoint was saved.
netCD = nn.DataParallel(netCD, device_ids=[5])
# map_location="cpu" deserializes on the CPU whatever device the tensors were
# saved from; the whole model is moved to deviceCD right after.
netCD.load_state_dict(torch.load(PATH2, map_location="cpu"))
netCD.to(deviceCD)
netCD.eval()  # inference mode: Dropout3d off, BatchNorm uses running statistics

## Define the network for real-time reference reconstruction.
# Net.__init__ reads these module-level names at construction time, so they are
# (re)bound here, right before the reference network class is defined.
(Hidden_layer, Conv_kernel, Conv_kerenl_time,
 Padd_space, Padd_time, drop_out_level, Bias) = (64, 3, 3, 1, 1, 0.15, True)
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    """Residual 3D U-Net (two max-pooling levels) for the reference reconstruction.

    Input is a single-channel 5-D tensor ``(N, 1, D1, D2, D3)``. The network
    returns ``x + unet(x)``: the U-Net branch predicts a correction that is
    added to the input, so the output has the same shape as ``x``.

    Layer sizes come from the module-level hyper-parameters (``Hidden_layer``,
    ``Conv_kernel``, ``Conv_kerenl_time``, ``Padd_space``, ``Padd_time``,
    ``drop_out_level``, ``Bias``), read when an instance is constructed.

    Attribute names and sub-module names form the ``state_dict`` keys, so they
    must stay exactly as they are for saved checkpoints to load.
    """

    @staticmethod
    def _conv_bn(conv_name, bn_name, in_channels, out_channels):
        """Return a ``Sequential`` holding one Conv3d followed by BatchNorm3d.

        The child names are passed explicitly because they appear in the
        state_dict keys (e.g. ``conv_DL1.Conv_DL1.weight``).
        """
        block = torch.nn.Sequential()
        block.add_module(
            conv_name,
            nn.Conv3d(
                in_channels,
                out_channels,
                (Conv_kernel, Conv_kernel, Conv_kerenl_time),
                padding=(Padd_space, Padd_space, Padd_time),
                stride=1,
                bias=Bias,
            ),
        )
        block.add_module(bn_name, nn.BatchNorm3d(out_channels))
        return block

    @staticmethod
    def _dropout():
        """Return the Dropout3d applied after every conv/BN pair."""
        return nn.Dropout3d(p=drop_out_level, inplace=True)

    @staticmethod
    def _max_pool(name):
        """Return a ``Sequential`` holding one 2x2x2, stride-2 max-pool named ``name``."""
        pool = torch.nn.Sequential()
        pool.add_module(name, nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)))
        return pool

    def __init__(self):
        super().__init__()
        h = Hidden_layer

        # Encoder, level 1 (full resolution)
        self.conv_DL1 = self._conv_bn("Conv_DL1", "BN1_DL1", 1, h)
        self.DropOut1 = self._dropout()
        self.conv_DL1_v2 = self._conv_bn("Conv_DL1_v2", "BN1_DL1_v2", h, h)
        self.DropOut2 = self._dropout()
        self.conv_MP1 = self._max_pool("Max Pool 1")

        # Encoder, level 2 (after one pooling)
        self.conv_DL2 = self._conv_bn("Conv_DL2", "BN1_DL2", h, 2 * h)
        self.DropOut3 = self._dropout()
        self.conv_DL2_v2 = self._conv_bn("Conv_DL2_v2", "BN1_DL2_v2", 2 * h, 2 * h)
        self.DropOut4 = self._dropout()
        self.conv_MP2 = self._max_pool("Max Pool 2")

        # Bottleneck (after two poolings)
        self.conv_DL3 = self._conv_bn("Conv_DL3", "BN1_DL3", 2 * h, 4 * h)
        self.DropOut5 = self._dropout()
        self.conv_DL3_v2 = self._conv_bn("Conv_DL3_v2", "BN1_DL3_v2", 4 * h, 4 * h)
        self.DropOut6 = self._dropout()

        # Decoder, level 2: upsample, then concat with the level-2 skip (2h + 2h channels)
        self.convT1 = nn.ConvTranspose3d(4 * h, 2 * h, (2, 2, 2), stride=(2, 2, 2))
        self.conv_UP1 = self._conv_bn("Conv_UP1", "BN1_UP1", 4 * h, 2 * h)
        self.DropOut7 = self._dropout()
        self.conv_UP1_v2 = self._conv_bn("Conv_UP1_v2", "BN1_UP1_v2", 2 * h, 2 * h)
        self.DropOut8 = self._dropout()

        # Decoder, level 1: upsample, then concat with the level-1 skip (h + h channels)
        self.convT2 = nn.ConvTranspose3d(2 * h, h, (2, 2, 2), stride=(2, 2, 2))
        self.conv_UP2 = self._conv_bn("Conv_UP2", "BN1_UP2", 2 * h, h)
        self.DropOut9 = self._dropout()
        self.conv_UP2_v2 = self._conv_bn("Conv_UP2_v2", "BN1_UP2_v2", h, h)
        self.DropOut10 = self._dropout()

        # 1x1x1 projection back to a single output channel
        self.conv_final = torch.nn.Sequential()
        self.conv_final.add_module(
            "Conv Final",
            nn.Conv3d(h, 1, (1, 1, 1), padding=(0, 0, 0), stride=1, bias=Bias),
        )

    def forward(self, x):
        """Return ``x + unet(x)`` for ``x`` of shape (N, 1, D1, D2, D3)."""
        # Sub-modules are invoked through __call__ (not .forward) so registered hooks run.
        x_down1 = F.relu(self.DropOut1(self.conv_DL1(x)))
        x_down1_v2 = F.relu(self.DropOut2(self.conv_DL1_v2(x_down1)))
        x_MaxPool = self.conv_MP1(x_down1_v2)

        x_down2 = F.relu(self.DropOut3(self.conv_DL2(x_MaxPool)))
        x_down2_v2 = F.relu(self.DropOut4(self.conv_DL2_v2(x_down2)))
        x_MaxPool_v2 = self.conv_MP2(x_down2_v2)

        x_down3 = F.relu(self.DropOut5(self.conv_DL3(x_MaxPool_v2)))
        x_down3_v2 = F.relu(self.DropOut6(self.conv_DL3_v2(x_down3)))

        # output_size restores odd lengths lost to floor-mode pooling, so the
        # upsampled tensor matches the skip connection for torch.cat.
        x_up1_ConvT = self.convT1(x_down3_v2, output_size=x_down2_v2.size())
        x_down2_up1_stack = torch.cat((x_down2_v2, x_up1_ConvT), 1)
        x_up1 = F.relu(self.DropOut7(self.conv_UP1(x_down2_up1_stack)))
        x_up1_v2 = F.relu(self.DropOut8(self.conv_UP1_v2(x_up1)))

        x_up2_ConvT = self.convT2(x_up1_v2, output_size=x_down1_v2.size())
        x_down1_up2_stack = torch.cat((x_down1_v2, x_up2_ConvT), 1)
        x_up2 = F.relu(self.DropOut9(self.conv_UP2(x_down1_up2_stack)))
        x_up2_v2 = F.relu(self.DropOut10(self.conv_UP2_v2(x_up2)))

        # Residual connection: the U-Net branch predicts a correction to the input.
        return x + self.conv_final(x_up2_v2)
netRef = Net()  # instance of the residual U-Net class defined immediately above

## Load the trained residual 3D U-Net for real-time reference reconstruction
PATH = os.path.join(file_Initial, "Model", "Unet3DRes", "")
PATH2 = os.path.join(PATH, "model.py")  # file holding the saved state_dict
# Wrap before loading: the saved keys presumably carry DataParallel's "module."
# prefix — TODO confirm against how the checkpoint was saved.
netRef = nn.DataParallel(netRef, device_ids=[5])
# map_location="cpu" deserializes on the CPU whatever device the tensors were
# saved from; the whole model is moved to deviceCD right after.
netRef.load_state_dict(torch.load(PATH2, map_location="cpu"))
netRef.to(deviceCD)  # same GPU as the complex-difference network
netRef.eval()  # inference mode: Dropout3d off, BatchNorm uses running statistics

# Folders holding the undersampled real-time PC datasets; files with the same
# name in both folders form one (reference, velocity-encoding) pair.
search_path_save_ZF_ref = file_Initial + "Data/NuFFT_ZF_ref/"
search_path_save_ZF_venc = file_Initial + "Data/NuFFT_ZF_venc/"
Crop_nx = 160           # in-plane size of the square centre crop
normalize_window = 48   # size of the central window used for intensity normalisation
nt_use = 120            # time frames are sliced as 1:nt_use below
# os.listdir returns entries in arbitrary order; sort so the file picked below is
# deterministic, and fail with a clear message if the folder is empty.
ListofFiles_zp = sorted(os.listdir(search_path_save_ZF_ref))
if not ListofFiles_zp:
    raise FileNotFoundError(f"No datasets found in {search_path_save_ZF_ref}")


def _load_mat_stack(path):
    """Load all variables of a .mat file and stack them into one complex64 array.

    The variables are stacked along a new leading axis, as
    ``np.complex64(list(mat.values()))`` does. Keys starting with ``'__'`` are
    skipped: hdf5storage.loadmat can fall back to scipy.io.loadmat for pre-v7.3
    files, whose result also holds '__header__'-style metadata entries that
    cannot be converted to complex64.
    """
    mat = hdf5storage.loadmat(path)
    return np.complex64([value for key, value in mat.items() if not key.startswith('__')])


# Load the first undersampled real-time PC dataset (reference and velocity-encoding images).
ListofFiles_zp_string = str(ListofFiles_zp[0])
filename_zp = ListofFiles_zp_string
load_path_zp_ref = search_path_save_ZF_ref + filename_zp
load_path_zp_venc = search_path_save_ZF_venc + filename_zp
mat_zp_venc = _load_mat_stack(load_path_zp_venc)
mat_zp_ref = _load_mat_stack(load_path_zp_ref)

# Centre-crop both in-plane axes to Crop_nx x Crop_nx (160 x 160).
# NOTE(review): axes are presumably (variable, x, y, time). The time slice starts
# at 1, so frame 0 is dropped and at most nt_use - 1 frames are kept — confirm
# this matches how the networks were trained.
startx1 = np.floor(mat_zp_ref.shape[1] / 2 - Crop_nx / 2).astype(int)
endx1 = np.floor(mat_zp_ref.shape[1] / 2 + Crop_nx / 2).astype(int)
starty1 = np.floor(mat_zp_ref.shape[2] / 2 - Crop_nx / 2).astype(int)
endy1 = np.floor(mat_zp_ref.shape[2] / 2 + Crop_nx / 2).astype(int)
mat_zp_ref = mat_zp_ref[:, startx1:endx1, starty1:endy1, 1:nt_use]
mat_zp_venc = mat_zp_venc[:, startx1:endx1, starty1:endy1, 1:nt_use]


# Network input buffers: real and imaginary parts are stacked along the first
# image axis, giving float32 arrays of shape (1, 1, 2 * nx, ny, nt).
_ri_shape = [1, 1, mat_zp_ref.shape[1] * 2, mat_zp_ref.shape[2], mat_zp_ref.shape[3]]
inpt_all_RI_ref = np.zeros(_ri_shape, dtype='float32')
outpt_all_RI_ref = np.zeros(_ri_shape, dtype='float32')   # NOTE(review): never filled or read in this script
inpt_all_RI_venc = np.zeros(_ri_shape, dtype='float32')
outpt_all_RI_venc = np.zeros(_ri_shape, dtype='float32')  # NOTE(review): never filled or read in this script

# Normalise both datasets by the 95th percentile of |reference| inside a central
# normalize_window x normalize_window window. The velocity-encoding data is
# divided by the same reference-derived value, so both keep a common scale.
nxx = mat_zp_ref.shape[1]
nx_crop2 = normalize_window
startx = np.floor(mat_zp_ref.shape[1] / 2 - nx_crop2 / 2).astype(int)
endx = np.floor(mat_zp_ref.shape[1] / 2 + nx_crop2 / 2).astype(int)
# Separate y bounds so a non-square image is still windowed about its centre.
starty = np.floor(mat_zp_ref.shape[2] / 2 - nx_crop2 / 2).astype(int)
endy = np.floor(mat_zp_ref.shape[2] / 2 + nx_crop2 / 2).astype(int)
mat_zp_ref_crop = mat_zp_ref[:, startx:endx, starty:endy, :]
mat_zp_venc_crop = mat_zp_venc[:, startx:endx, starty:endy, :]
norm_scale = np.percentile(np.abs(mat_zp_ref_crop), 95)  # computed once, used for both
mat_zp_ref = mat_zp_ref / norm_scale
mat_zp_venc = mat_zp_venc / norm_scale

# Concatenate real and imaginary components along the first image axis (needed
# because the CNN is real-valued): rows [0, nx) = real part, [nx, 2nx) = imaginary.
inpt_all_RI_ref[0, 0, 0:nxx, :, :] = np.real(mat_zp_ref[0, :, :, :])
inpt_all_RI_ref[0, 0, nxx:nxx * 2, :, :] = np.imag(mat_zp_ref[0, :, :, :])
inpt_all_RI_venc[0, 0, 0:nxx, :, :] = np.real(mat_zp_venc[0, :, :, :])
inpt_all_RI_venc[0, 0, nxx:nxx * 2, :, :] = np.imag(mat_zp_venc[0, :, :, :])

# Run both networks for inference only. torch.no_grad() skips building an
# autograd graph, which is never used here and would only hold extra memory.
with torch.no_grad():
    # Reconstruction of the reference dataset (residual U-Net)
    inpt_ref_recon_v2 = netRef(torch.from_numpy(np.single(inpt_all_RI_ref)))
    inpt_ref_recon_v2 = inpt_ref_recon_v2.cpu().numpy()

    # Reconstruction of the complex-difference dataset (reference - velocity encoding)
    inpt_CD_recon_v2 = netCD(torch.from_numpy(inpt_all_RI_ref - inpt_all_RI_venc))
    inpt_CD_recon_v2 = inpt_CD_recon_v2.cpu().numpy()

def _ri_to_complex(stacked):
    """Rebuild a complex array from real/imaginary halves stacked along axis 2.

    Axis 2 holds the real part in its first half and the imaginary part in its
    second half (the packing used for the network inputs). Integer division
    replaces the former ``np.int`` alias, which NumPy 1.24 removed.
    """
    half = stacked.shape[2] // 2
    return stacked[:, :, :half, :, :] + 1j * stacked[:, :, half:, :, :]


# Organise the datasets for visualisation
outp_netCD = inpt_CD_recon_v2
outp_netRef_CD = inpt_ref_recon_v2
input_ref = inpt_all_RI_ref
input_venc = inpt_all_RI_venc

input_CD = input_ref - input_venc
inpt_ref2 = _ri_to_complex(input_ref)
input_CD2 = _ri_to_complex(input_CD)
inpt_venc2 = _ri_to_complex(input_venc)

Net_ref = outp_netRef_CD
Net_CD = outp_netCD
Net_ref2 = _ri_to_complex(Net_ref)
Net_CD2 = _ri_to_complex(Net_CD)

# Phase-contrast image: phase of reference * conj(velocity encoding).
PhaseContrast_ZF = np.angle(inpt_ref2 * np.conj(inpt_venc2))
# The networks reconstruct the reference and CD = reference - venc, so venc = reference - CD.
Net_venc2 = -1 * (Net_CD2 - Net_ref2)
PhaseContrast_DL = np.angle(Net_ref2 * np.conj(Net_venc2))

# Side-by-side comparisons at one time frame: left = network input,
# right = network reconstruction.
time_use = 26
# Display window for the phase-contrast images.
# NOTE(review): 3.16 / 2 is close to, but not exactly, pi / 2 — confirm intended.
cmin = -3.16/2
cmax = 3.16/2

# (left image, right image, vmin, vmax) for the complex-difference, reference
# and phase-contrast figures, in that order.
comparisons = (
    (np.abs(input_CD2[0, 0, :, :, time_use]), np.abs(Net_CD2[0, 0, :, :, time_use]), 0, .95),
    (np.abs(inpt_ref2[0, 0, :, :, time_use]), np.abs(Net_ref2[0, 0, :, :, time_use]), 0, 2.5),
    (PhaseContrast_ZF[0, 0, :, :, time_use], PhaseContrast_DL[0, 0, :, :, time_use], cmin, cmax),
)
for left_img, right_img, lo, hi in comparisons:
    fig, ax = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(13, 5))
    ax[0].imshow(left_img, cmap='gray', vmin=lo, vmax=hi)
    ax[1].imshow(right_img, cmap='gray', vmin=lo, vmax=hi)

# Display the figures; when run as a plain script nothing is shown without this call.
plt.show()
